import itertools
from collections import Counter
from pylab import *


# Time points of the time course, in hours. Used further down to map a
# library's time point onto the colour scale of the time-point strip.
timepoints = array([0, 1, 4, 12, 24, 96])

# Condition label of every library, grouped by dataset:
#     conditions[dataset][library] -> condition label
# The order of the libraries within each dataset is the order in which
# they are read and plotted.
conditions = {
    "StartSeq": {
        "SRR7071452": "6 hr",
        "SRR7071453": "6 hr",
    },
    "MiSeq": {
        "t00_r1": "0 hr",
        "t00_r2": "0 hr",
        "t00_r3": "0 hr",
        "t01_r1": "1 hr",
        "t01_r2": "1 hr",
        "t01_r3": "1 hr",
        "t04_r1": "4 hr",
        "t04_r2": "4 hr",
        "t04_r3": "4 hr",
        "t12_r1": "12 hr",
        "t12_r2": "12 hr",
        "t12_r3": "12 hr",
        "t24_r1": "24 hr",
        "t24_r2": "24 hr",
        "t24_r3": "24 hr",
        "t96_r1": "96 hr",
        "t96_r2": "96 hr",
        "t96_r3": "96 hr",
    },
    # Note: there is no "t01_r3" entry for HiSeq.
    "HiSeq": {
        "t00_r1": "0 hr",
        "t00_r2": "0 hr",
        "t00_r3": "0 hr",
        "t01_r1": "1 hr",
        "t01_r2": "1 hr",
        "t04_r1": "4 hr",
        "t04_r2": "4 hr",
        "t04_r3": "4 hr",
        "t12_r1": "12 hr",
        "t12_r2": "12 hr",
        "t12_r3": "12 hr",
        "t24_r1": "24 hr",
        "t24_r2": "24 hr",
        "t24_r3": "24 hr",
        "t96_r1": "96 hr",
        "t96_r2": "96 hr",
        "t96_r3": "96 hr",
    },
    "CAGE": {
        "00_hr_A": "0 hr",
        "00_hr_C": "0 hr",
        "00_hr_G": "0 hr",
        "00_hr_H": "0 hr",
        "01_hr_A": "1 hr",
        "01_hr_C": "1 hr",
        "01_hr_G": "1 hr",
        "04_hr_C": "4 hr",
        "04_hr_E": "4 hr",
        "12_hr_A": "12 hr",
        "12_hr_C": "12 hr",
        "24_hr_C": "24 hr",
        "24_hr_E": "24 hr",
        "96_hr_A": "96 hr",
        "96_hr_C": "96 hr",
        "96_hr_E": "96 hr",
    },
}

def read_annotation_counts(dataset):
    """Tally the reads per annotation category for each library of a dataset.

    The counts are read from the tab-separated file
    ``annotations.<dataset>.txt`` in the current working directory.  Its
    first line is a header that starts with ``#`` and names the columns
    ``rank``, ``annotation`` and ``transcript``, followed by one column per
    library.  Every following line holds the zero-based row number, an
    annotation, a transcript, and one integer read count per library.

    The sense-type annotations (``sense_proximal``, ``sense_distal``,
    ``sense_upstream``, ``sense_distal_upstream``) are pooled as ``sense``
    and the antisense-type annotations (``antisense``, ``antisense_distal``,
    ``prompt``, ``antisense_distal_upstream``) as ``antisense``; any other
    annotation is kept under its own name.

    Only the libraries listed in the module-level ``conditions[dataset]``
    are tallied; columns of other libraries are ignored.

    Returns a dictionary mapping each library name to a ``Counter`` that
    maps annotation to the total read count.

    Raises ``AssertionError`` if the header is malformed, if a library in
    ``conditions[dataset]`` has no column in the file, if a row number is
    out of sequence, or if a row has the wrong number of count columns.
    """
    sense_annotations = ('sense_proximal', "sense_distal",
                         "sense_upstream", 'sense_distal_upstream')
    antisense_annotations = ("antisense", "antisense_distal",
                             "prompt", "antisense_distal_upstream")
    filename = "annotations.%s.txt" % dataset
    print("Reading", filename)
    # The context manager closes the file even if one of the format checks
    # below fails part-way through.
    with open(filename) as handle:
        header = next(handle).strip()
        assert header.startswith("#")
        words = header[1:].split("\t")
        assert words[0] == 'rank'
        assert words[1] == 'annotation'
        assert words[2] == 'transcript'
        libraries = words[3:]
        counts = {}
        for library in conditions[dataset]:
            counts[library] = Counter()
            assert library in libraries
        for number, line in enumerate(handle):
            words = line.strip().split("\t")
            # The rank column must simply number the rows from zero
            assert int(words[0]) == number
            annotation = words[1].strip()
            if annotation in sense_annotations:
                annotation = "sense"
            elif annotation in antisense_annotations:
                annotation = "antisense"
            values = words[3:]
            assert len(values) == len(libraries)
            for library, value in zip(libraries, values):
                # Skip columns of libraries not listed in conditions
                if library in counts:
                    counts[library][annotation] += int(value)
    return counts

def _percentage(part, whole):
    """Return part as a percentage of whole, or NaN if whole is zero."""
    if whole == 0:
        return float("nan")
    return part * 100.0 / whole


def summarize(dataset, library):
    """Print summary percentages of the annotation counts of one library.

    The counts are taken from the module-level ``counts[dataset][library]``
    mapping of annotation to read count.  Four lines are printed:

    - the percentage of all reads that were unmapped or mapped to rRNA or
      chrM;
    - of the reads remaining after removing those, the percentage mapped
      to short RNAs (Pol-II, Pol-III, or intronic);
    - of the same remaining reads, the percentage annotated as sense,
      antisense, or histone;
    - of the intergenic reads, the percentage mapped to FANTOM5 enhancers.

    A percentage whose denominator is zero (for example, a library without
    any intergenic reads) is printed as ``nan`` instead of raising
    ``ZeroDivisionError``.

    Raises ``Exception`` if the library has an annotation that is not in
    any of the known categories.
    """
    skipped_annotations = ("unmapped", "chrM", 'rRNA')
    short_polyAminus_RNAs = ("Pol-II short RNA",
                             "Pol-III short RNA",
                             "intronic short RNA")
    gene_associated_RNAs = ('sense', "antisense", "histone")
    intergenic_RNAs = ("FANTOM5_enhancer", "roadmap_dyadic", "roadmap_enhancer",
                       "novel_enhancer_HiSeq", "novel_enhancer_CAGE",
                       "other_intergenic")
    total_count = 0
    gene_count = 0
    skip_count = 0
    short_polyAminus_RNA_count = 0
    fantom5_enhancer_count = 0
    intergenic_count = 0
    for annotation, count in counts[dataset][library].items():
        if annotation in skipped_annotations:
            skip_count += count
        elif annotation in short_polyAminus_RNAs:
            short_polyAminus_RNA_count += count
        elif annotation in gene_associated_RNAs:
            gene_count += count
        elif annotation in intergenic_RNAs:
            intergenic_count += count
            if annotation == "FANTOM5_enhancer":
                fantom5_enhancer_count += count
        elif annotation == "short RNA precursor":
            # Contributes to the total only, not to any reported category
            pass
        else:
            raise Exception("Unknown annotation %s" % annotation)
        total_count += count
    print("%s, %s: %.2f%% were unmapped or mapped to rRNA or chrM" % (dataset, library, _percentage(skip_count, total_count)))
    # The next two percentages are relative to the reads that were not skipped
    total_count -= skip_count
    print("%s, %s: %.2f%% of remaining reads mapped to known small poly(A)minus RNAs" % (dataset, library, _percentage(short_polyAminus_RNA_count, total_count)))
    print("%s, %s: %.2f%% of remaining reads were associated with genes" % (dataset, library, _percentage(gene_count, total_count)))
    print("%s, %s: %.2f%% of intergenic reads mapped to FANTOM5 enhancers" % (dataset, library, _percentage(fantom5_enhancer_count, intergenic_count)))

# Bar colour of each annotation category. The order of the entries is the
# order of the rows in the output table and the stacking order (bottom to
# top) of the bars in the "All reads" panel.
colors = {
    "unmapped": "black",
    "chrM": "gray",
    "rRNA": "darkgray",
    "Pol-II short RNA": "forestgreen",
    "Pol-III short RNA": "limegreen",
    "intronic short RNA": "mediumspringgreen",
    "short RNA precursor": "deepskyblue",
    "histone": "navy",
    "sense": "indianred",
    "antisense": "maroon",
    "FANTOM5_enhancer": "gold",
    "roadmap_enhancer": "goldenrod",
    "roadmap_dyadic": "khaki",
    "novel_enhancer_CAGE": "darkgoldenrod",
    "novel_enhancer_HiSeq": "orange",
    "other_intergenic": "whitesmoke",
}

# All annotation categories, in the order of the colour table
annotations = tuple(colors)

# Read the annotation counts of all datasets.
#     counts[dataset][library] -> Counter of annotation -> read count
# `datasets` and `libraries` are parallel lists with one entry per library,
# in plotting order; `seen` collects every annotation that was encountered.
counts = {}
datasets = []
libraries = []
seen = set()
for dataset in ("StartSeq", "MiSeq", "HiSeq", "CAGE"):
    dataset_counts = read_annotation_counts(dataset)
    counts[dataset] = dataset_counts
    datasets.extend(dataset for _ in dataset_counts)
    libraries.extend(dataset_counts)
    for tally in dataset_counts.values():
        seen.update(tally)

# Every annotation found must have a colour, and every colour must be used
assert seen == set(colors)

# m: number of libraries (columns); n: number of annotations (rows)
m = len(libraries)
assert len(datasets) == m
n = len(colors)
data = zeros((n, m))

# Write the raw counts as a tab-separated table with three header rows
# (dataset, library, condition) followed by one row per annotation, and
# store the same counts in `data` for plotting.
filename = "table_annotations_timecourse.txt"
print("Writing %s" % filename)
# The context manager closes the file even if writing fails part-way
with open(filename, 'wt') as stream:
    row1 = ["dataset"]
    row2 = ["library"]
    row3 = ["condition"]
    # Build the header from the same (dataset, library) sequence as the
    # count columns below, so that headers and columns cannot get out of
    # step.
    for dataset, library in zip(datasets, libraries):
        row1.append(dataset)
        row2.append(library)
        row3.append(conditions[dataset][library])
    for row in (row1, row2, row3):
        stream.write("\t".join(row) + "\n")
    for i, annotation in enumerate(annotations):
        row = [annotation]
        for j, (dataset, library) in enumerate(zip(datasets, libraries)):
            count = counts[dataset][library].get(annotation, 0)
            row.append(str(count))
            data[i, j] = count
        stream.write("\t".join(row) + "\n")

# Convert the counts of each library (column) to percentages
data /= data.sum(axis=0)
data *= 100

# Horizontal position of the bar of each library. Bars are one unit apart,
# with one extra unit of space in front of each dataset.
#     x: bar centres, parallel to `libraries`
#     locations[dataset][library] -> bar centre
x = []
locations = {}
current_dataset = None
xx = -0.5
for dataset, library in zip(datasets, libraries):
    if current_dataset != dataset:
        # First library of a new dataset: leave a gap
        current_dataset = dataset
        locations[dataset] = {}
        xx += 1
    x.append(xx)
    locations[dataset][library] = xx
    xx += 1

f = figure(figsize=(8, 7))

# Panel a: all annotation categories, as stacked percentage bars.
# Axes geometry is given in figure coordinates.
left, right = 0.10, 0.75
bottom2, top2 = 0.38, 0.86
width = right - left
height = top2 - bottom2
ax = f.add_axes([left, bottom2, width, height])

# Stack the annotations on top of each other, one bar per library
bottom = zeros(m)
for annotation, row in zip(annotations, data):
    bar(x, row, width=1.0, bottom=bottom, color=colors[annotation],
        label=annotation)
    bottom += row

xticks([])
ylabel("Percentage", fontsize=8)
yticks(fontsize=8)
ylim(0, 100)
xmax = max(x) + 0.5
xlim(0, xmax)

# Panel letter to the left of the axes, panel title to the right
f.text(left - 0.09, top2, "a", fontsize=16,
       horizontalalignment='left', verticalalignment='top')
f.text(right + 0.01, top2, "All reads", fontsize=8, fontweight='bold',
       horizontalalignment='left', verticalalignment='top')

# Panel b: sense, antisense, and intergenic reads only
bottom3, top3 = 0.20, 0.35
width = right - left
height = top3 - bottom3
f.add_axes([left, bottom3, width, height])

selected_annotations = (
    "sense",
    "antisense",
    "FANTOM5_enhancer",
    "roadmap_enhancer",
    "roadmap_dyadic",
    "novel_enhancer_CAGE",
    "novel_enhancer_HiSeq",
    "other_intergenic",
)
n = len(selected_annotations)
data = zeros((n, m))
for j, (dataset, library) in enumerate(zip(datasets, libraries)):
    library_counts = counts[dataset][library]
    for i, annotation in enumerate(selected_annotations):
        data[i, j] = library_counts.get(annotation, 0)

# Percentages relative to the selected annotations of each library
data /= sum(data, 0)
data *= 100

# Per dataset, collect the gene-associated percentage of each library;
# the first two rows of `data` are "sense" and "antisense".
percentage_gene_associated = {dataset: [] for dataset in datasets}
for j, dataset in enumerate(datasets):
    percentage_gene_associated[dataset].append(sum(data[:2, j]))

print("Panel B, percentage of reads associated with genes:")
for dataset, percentages in percentage_gene_associated.items():
    print("%s: %.2f%%" % (dataset, mean(percentages)))


bottom = zeros(m)
for annotation, row in zip(selected_annotations, data):
    bar(x, row, width=1.0, bottom=bottom, color=colors[annotation],
        label=annotation)
    bottom += row

xticks([])
ylabel("Percentage", fontsize=8)
yticks(fontsize=8)
ylim(0, 100)
xmax = max(x) + 0.5
xlim(0, xmax)

# Panel letter to the left of the axes, panel title to the right
f.text(left - 0.09, top3, "b", fontsize=16,
       horizontalalignment='left', verticalalignment='top')
f.text(right + 0.01, top3, "Reads associated with mRNA\nor lncRNA genes in the sense\nor antisense orientation,\nand intergenic reads", fontsize=8, fontweight='bold',
       horizontalalignment='left', verticalalignment='top')

# Panel c: intergenic reads only
bottom4, top4 = 0.02, 0.17
width = right - left
height = top4 - bottom4
f.add_axes([left, bottom4, width, height])

intergenic_annotations = (
    "FANTOM5_enhancer",
    "roadmap_enhancer",
    "roadmap_dyadic",
    "novel_enhancer_CAGE",
    "novel_enhancer_HiSeq",
    "other_intergenic",
)
n = len(intergenic_annotations)
data = zeros((n, m))
for j, (dataset, library) in enumerate(zip(datasets, libraries)):
    library_counts = counts[dataset][library]
    for i, annotation in enumerate(intergenic_annotations):
        data[i, j] = library_counts.get(annotation, 0)

# Percentages relative to the intergenic reads of each library
data /= sum(data, 0)
data *= 100

bottom = zeros(m)
for annotation, row in zip(intergenic_annotations, data):
    bar(x, row, width=1.0, bottom=bottom, color=colors[annotation],
        label=annotation)
    bottom += row

xticks([])
ylabel("Percentage", fontsize=8)
yticks(fontsize=8)
ylim(0, 100)
xmax = max(x) + 0.5
xlim(0, xmax)

# Panel letter to the left of the axes, panel title to the right
f.text(left - 0.09, top4, "c", fontsize=16,
       horizontalalignment='left', verticalalignment='top')
f.text(right + 0.01, top4, "Intergenic reads only", fontsize=8, fontweight='bold',
       horizontalalignment='left', verticalalignment='top')

# Strip above panel a: one coloured cell per library showing its time point,
# with the time point printed once per group of replicates.
bottom1 = 0.89
top1 = 0.92
width = right - left
height = top1 - bottom1
ax2 = f.add_axes([left, bottom1, width, height])

# Hide the frame and the ticks of the strip
for side in ('top', 'bottom', 'left', 'right'):
    ax2.spines[side].set_color('none')
ax2.tick_params(labelcolor='w', top=False, bottom=False, left=False, right=False)

color = []
current_dataset = None
current_timepoint = None
# xx tracks the left edge of the current cell, in the same coordinates as
# the bar centres in x (which lie half a unit to the right of it).
xx = -1
positions = []
labels = []
for (dataset, library) in zip(datasets, libraries):
    # Derive the time point from the library name. An exact comparison is
    # used for the dataset name; a substring test (`in "StartSeq"`) would
    # also accept names such as "Seq" or the empty string.
    if dataset == "StartSeq":
        assert library in ("SRR7071452", "SRR7071453")
        timepoint = 6
    elif dataset in ("MiSeq", "HiSeq"):
        assert library.startswith("t")
        assert library[-3:] in ("_r1", "_r2", "_r3")
        timepoint = int(library[1:3])
    elif dataset == "CAGE":
        assert library[-5:] in ("_hr_A", "_hr_C", "_hr_E", "_hr_G", "_hr_H")
        timepoint = int(library[:2])
    else:
        raise Exception("Unknown dataset %s" % dataset)
    # A new group of cells starts when the time point changes, and also
    # when the dataset changes, so that the gap between datasets is always
    # inserted, consistent with the bar positions in x.
    if dataset != current_dataset or timepoint != current_timepoint:
        if current_timepoint is not None:
            # Close the previous group: label it at its centre
            position = (start + xx) / 2.0
            positions.append(position)
            labels.append(str(current_timepoint))
        current_timepoint = timepoint
        # Fractional rank of the time point among the sampled time points
        index = interp(timepoint, timepoints, arange(len(timepoints)))
        if dataset != current_dataset:
            current_dataset = dataset
            xx += 1
        start = xx
    color.append(cm.Blues(index/10))
    xx += 1
# Close the last group
position = (start + xx) / 2.0
positions.append(position)
labels.append(str(current_timepoint))
assert len(color) == m
bar(x, ones(m), width=1.0, color=color)
for position, label in zip(positions, labels):
    text(position, 0.5, label, horizontalalignment='center', verticalalignment='center', fontsize=10, color='black')
xticks([])
yticks([])
xlim(0, xmax)
ylim(0, 1)

# Two-line heading (RNA type, library type) printed above the bars of each
# dataset.
headings = {
    "StartSeq": ("Start-Seq", ""),
    "MiSeq": ("Short capped RNAs,", "paired-end libraries"),
    "HiSeq": ("Short capped RNAs,", "single-end libraries"),
    "CAGE": ("Long capped RNAs,", "CAGE libraries"),
}
for dataset in locations:
    if dataset not in headings:
        raise Exception("Unknown dataset %s" % dataset)
    rna, library = headings[dataset]
    # Centre of the bars of this dataset, converted from axes data
    # coordinates (x running from 0 to xmax) to figure coordinates
    positions = list(locations[dataset].values())
    position = left + (right - left) * (mean(positions) / xmax)
    f.text(position, top1+0.04, rna, fontsize=8, horizontalalignment='center')
    f.text(position, top1+0.02, library, fontsize=8, horizontalalignment='center')

# Titles in the right margin for the time-point strip and for the legend
position = right + 0.01
f.text(position, (bottom1+top1)/2, "Time point [hr]", fontsize=8, fontweight='bold',
       horizontalalignment='left', verticalalignment='center')
f.text(position, top2-0.14, "Annotations:", fontsize=8, fontweight='bold',
       horizontalalignment='left', verticalalignment='top')

# Legend of panel a, reversed so that its order follows the bars from top to
# bottom
handles, labels = ax.get_legend_handles_labels()
ax.legend(list(reversed(handles)), list(reversed(labels)),
          bbox_to_anchor=(1.0,0.), loc='lower left', fontsize=8,
          labelspacing=0.25, frameon=False)


# Save the figure both as a bitmap and as a vector image
for filename in ("figure_annotations_timecourse.png",
                 "figure_annotations_timecourse.svg"):
    print("Saving figure as %s" % filename)
    savefig(filename)
